Feat/selective offload on srelu fuser#3047
dba3531 to 53e6511
Greptile Summary
This PR adds selective CPU activation offloading to the SReLU fused op (
Confidence Score: 4/5
The selective offloading logic is functional for the FC1 input and activation tensors; the main gap is that the FC2 GEMM input has no corresponding selective-offload control and is always pinned to GPU.
The core mechanism works correctly end-to-end. The FC2 input is unconditionally prevented from offloading, with no API to override this, which creates an asymmetry that may be an oversight. The non-V1 path uses prepare_for_saving destructively on GroupedTensorStorage objects, but this is safe here because quantized_tensors is always None at the point of the call.
Needs attention: forward_grouped_mlp.py, around the FC2 mark_not_offload block at lines 596-599, where the missing fine_grained_activation_offloading gate lives.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Forward fused op] --> B[Construct GroupedTensorStorage for fc1_x and fc2_x]
B --> C{cpu_offloading enabled?}
C -- No --> G[save_for_backward via OperationContext]
C -- Yes --> D{offload_fc1_input?}
D -- False --> E[mark_not_offload grouped_fc1_x]
D -- True --> F[fc1_x eligible for offload]
E --> H[mark_not_offload fc1 weights always]
F --> H
H --> I{offload_activation_input?}
I -- False --> J[mark_not_offload activation_in and scales]
I -- True --> K[activation tensors eligible for offload]
J --> L[mark_not_offload saved_grouped_fc2_x always]
K --> L
L --> M[mark_not_offload fc2 weights always]
M --> G
G --> N[fuser.py prepare_for_saving decomposes objects]
N --> O[PyTorch save_for_backward push_tensor hook]
O --> P{_TE_do_not_offload set?}
P -- Yes --> Q[Stay on GPU]
P -- No --> R[Offload to CPU]
Reviews (4): Last reviewed commit: "Simplify fused grouped MLP offload check..."
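The branching in the flowchart above can be restated as a small pure function. This is only a sketch of the decision logic, not the code in forward_grouped_mlp.py: the `OffloadFlags` class, the `tensors_pinned_to_gpu` function, and the string labels for the saved tensors are hypothetical names introduced here for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OffloadFlags:
    """Hypothetical container for the three conditions in the flowchart."""

    cpu_offloading: bool
    offload_fc1_input: bool
    offload_activation_input: bool


def tensors_pinned_to_gpu(flags: OffloadFlags) -> set:
    """Return labels of saved tensors that would be marked not-offload."""
    if not flags.cpu_offloading:
        # No marking happens; tensors go straight to save_for_backward.
        return set()
    # Weights and the FC2 input are always marked, per the flowchart.
    pinned = {"fc1_weights", "fc2_x", "fc2_weights"}
    if not flags.offload_fc1_input:
        pinned.add("fc1_x")
    if not flags.offload_activation_input:
        pinned.update({"activation_in", "activation_scales"})
    return pinned
```

Under this reading, enabling both flags still leaves the FC2 input pinned, which is the asymmetry the summary points out.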
selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr(
    activation_op, "activation_offloading"
)
offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False))
offload_activation_input = bool(getattr(activation_op, "activation_offloading", False))
Selective-offload gate never activates unless callers set activation_offloading on op objects
hasattr(fc1_op, "activation_offloading") checks for a dynamic attribute on the op module. mark_activation_offload (both V1 and non-V1) sets activation_offloading on tensors, not on op instances, and neither GroupedLinear nor ScaledSReLU declare this attribute. As written, selective_offload will always be False and none of the new marking logic will execute unless callers set fc1_op.activation_offloading = True externally. If this is intentional, the attribute name, type, and expected caller pattern should be documented; if not, the gate condition needs to match how the attribute is actually assigned.
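A minimal sketch of the behavior described in this comment, using a plain stand-in class rather than the real GroupedLinear or ScaledSReLU modules. `FakeOp` and `selective_offload_gate` are hypothetical names; the only assumption carried over from the diff is that the gate reads the attribute with `hasattr` and `getattr`.

```python
class FakeOp:
    """Stand-in for an op module that does not declare the attribute."""


def selective_offload_gate(fc1_op, activation_op):
    """Mirror the hasattr/getattr checks from the diff under review."""
    selective_offload = hasattr(fc1_op, "activation_offloading") or hasattr(
        activation_op, "activation_offloading"
    )
    offload_fc1_input = bool(getattr(fc1_op, "activation_offloading", False))
    offload_activation_input = bool(getattr(activation_op, "activation_offloading", False))
    return selective_offload, offload_fc1_input, offload_activation_input


fc1_op, activation_op = FakeOp(), FakeOp()

# Nothing declares the attribute, so the gate stays closed.
default_result = selective_offload_gate(fc1_op, activation_op)

# Only an external assignment on the op instance opens it.
fc1_op.activation_offloading = True
external_result = selective_offload_gate(fc1_op, activation_op)
```

Here `default_result` is `(False, False, False)` and `external_result` is `(True, True, False)`, so the new marking logic depends entirely on callers setting the attribute on the op instance.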
2c59510 to 6e01d0a
for more information, see https://pre-commit.ci
Overall looks good, but with one design suggestion.
Followup tasks after merging this PR:
- Enable activation checkpointing in the unfused grouped linear op.
- Update activation checkpointing to support v2 infrastructure from #1762, which is opt-out rather than opt-in.
offload_fc1_input = bool(getattr(fc1_op, "fine_grained_activation_offloading", False))
offload_activation_input = bool(
    getattr(activation_op, "fine_grained_activation_offloading", False)
)
- Do these options give us value? The dense linear op and activation ops don't expose this fine-grained control:
  TransformerEngine/transformer_engine/pytorch/ops/basic/basic_linear.py, lines 1052 to 1053 in ace2a96
  TransformerEngine/transformer_engine/pytorch/ops/basic/activation.py, lines 116 to 117 in ace2a96
- For consistency with the rest of the CPU offloading behavior, shouldn't the default be to enable offloading? Disabling offloading should be the explicit path.
- These secret undocumented attrs are delicate and unmaintainable. Better to make them arguments in the unfused ops. However, this also means we should update the unfused impls so that they disable activation checkpointing if the option is set.
Easiest just to not make this configurable.
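For comparison, the alternative raised in the bullets above (an explicit argument that defaults to offloading, with opting out as the explicit path) might look roughly like this. It is a sketch only: `SketchLinearOp`, its `offload_activations` argument, and `should_pin_input` are hypothetical and are not part of Transformer Engine.

```python
class SketchLinearOp:
    """Hypothetical op with a documented constructor argument."""

    def __init__(self, *, offload_activations: bool = True):
        # Offloading is the default; disabling it is the explicit choice.
        self.offload_activations = offload_activations


def should_pin_input(op: SketchLinearOp, cpu_offloading_enabled: bool) -> bool:
    """Pin the saved input only when offloading is on and the op opted out."""
    return cpu_offloading_enabled and not op.offload_activations
```

With this shape, `should_pin_input(SketchLinearOp(), True)` is `False`, and only `SketchLinearOp(offload_activations=False)` keeps its input on the GPU while offloading is enabled.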